{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# Estimator"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/estimator\" class=\"\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" class=\"\">在 TensorFlow.org 上查看</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/estimator.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" class=\"\">在 Google Colab 中运行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/estimator.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" class=\"\">在 GitHub 上查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/estimator.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" class=\"\">下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oEinLJt2Uowq"
      },
      "source": [
        "本文档介绍了 `tf.estimator`，它是一种高级 TensorFlow API。Estimator 封装了以下操作：\n",
        "\n",
        "- 训练\n",
        "- 评估\n",
        "- 预测\n",
        "- 导出以供使用\n",
        "\n",
        "您可以使用我们提供的预制 Estimator 或编写您自己的自定义 Estimator。所有 Estimator（无论是预制还是自定义）都是基于 `tf.estimator.Estimator` 类的类。\n",
        "\n",
        "有关简单示例，请查看 [Estimator 教程](https://render.githubusercontent.com/tutorials/estimator/linear.ipynb)。有关 API 设计概述，请参阅[白皮书](https://arxiv.org/abs/1708.02637)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wg5zbBliQvNL"
      },
      "source": [
        "## 优势\n",
        "\n",
        "与 `tf.keras.Model` 类似，`estimator` 是模型级别的抽象。`tf.estimator` 提供了一些目前仍在为 `tf.keras` 开发中的功能。包括：\n",
        "\n",
        "- 基于参数服务器的训练\n",
        "- 完整的 [TFX](http://tensorflow.google.cn/tfx) 集成"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yQ8fQYt_VD5E"
      },
      "source": [
        "## Estimator 功能\n",
        "\n",
        "Estimator 提供了以下优势：\n",
        "\n",
        "- 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型，而无需更改模型。此外，您还可以在 CPU、GPU 或 TPU 上运行基于 Estimator 的模型，而无需重新编码模型。\n",
        "- Estimator 提供了安全的分布式训练循环，可控制如何以及何时进行以下操作：\n",
        "    - 加载数据\n",
        "    - 处理异常\n",
        "    - 创建检查点文件并从故障中恢复\n",
        "    - 保存 TensorBoard 摘要\n",
        "\n",
        "在用 Estimator 编写应用时，您必须将数据输入流水线与模型分离。这种分离简化了使用不同数据集进行的实验。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sXNBeY-oVxGQ"
      },
      "source": [
        "## 预制 Estimator\n",
        "\n",
        "使用预制 Estimator，您能够在比基础 TensorFlow API 高很多的概念层面上工作。您无需再担心创建计算图或会话，因为 Estimator 会替您完成所有“基础工作”。此外，使用预制 Estimator，您只需改动较少代码就能试验不同的模型架构。例如，`tf.estimator.DNNClassifier` 是一个预制 Estimator 类，可基于密集的前馈神经网络对分类模型进行训练。\n",
        "\n",
        "### 预制 Estimator 程序结构\n",
        "\n",
        "依赖于预制 Estimator 的 TensorFlow 程序通常包括以下四个步骤：\n",
        "\n",
        "#### 1. 编写一个或多个数据集导入函数。\n",
        "\n",
        "例如，您可以创建一个函数来导入训练集，创建另一个函数来导入测试集。每个数据集导入函数必须返回以下两个对象：\n",
        "\n",
        "- 字典，其中键是特征名称，值是包含相应特征数据的张量（或 SparseTensor）\n",
        "- 包含一个或多个标签的张量\n",
        "\n",
        "例如，以下代码展示了输入函数的基本框架：\n",
        "\n",
        "```\n",
        "def input_fn(dataset):     ...  # manipulate dataset, extracting the feature dict and the label     return feature_dict, label\n",
        "```\n",
        "\n",
        "有关详细信息，请参阅[数据指南](https://render.githubusercontent.com/guide/data.md)。\n",
        "\n",
        "#### 2. 定义特征列。\n",
        "\n",
        "每个 `tf.feature_column` 标识了特征名称、特征类型，以及任何输入预处理。例如，以下代码段创建了三个包含整数或浮点数据的特征列。前两个特征列仅标识了特征的名称和类型。第三个特征列还指定了一个会被程序调用以缩放原始数据的 lambda：\n",
        "\n",
        "```\n",
        "# Define three numeric feature columns. population = tf.feature_column.numeric_column('population') crime_rate = tf.feature_column.numeric_column('crime_rate') median_education = tf.feature_column.numeric_column(   'median_education',   normalizer_fn=lambda x: x - global_education_mean)\n",
        "```\n",
        "\n",
        "有关详细信息，请参阅[特征列教程](https://tensorflow.google.cn/tutorials/keras/feature_columns)。\n",
        "\n",
        "#### 3. 实例化相关预制 Estimator。\n",
        "\n",
        "例如，下面是对名为 `LinearClassifier` 的预制 Estimator 进行实例化的示例：\n",
        "\n",
        "```\n",
        "# Instantiate an estimator, passing the feature columns. estimator = tf.estimator.LinearClassifier(   feature_columns=[population, crime_rate, median_education])\n",
        "```\n",
        "\n",
        "有关详细信息，请参阅[线性分类器教程](https://tensorflow.google.cn/tutorials/estimator/linear)。\n",
        "\n",
        "#### 4. 调用训练、评估或推断方法。\n",
        "\n",
        "例如，所有 Estimator 都会提供一个用于训练模型的 `train` 方法。\n",
        "\n",
        "```\n",
        "# `input_fn` is the function created in Step 1 estimator.train(input_fn=my_training_set, steps=2000)\n",
        "```\n",
        "\n",
        "您可以在下面看到与此相关的示例。\n",
        "\n",
        "### 预制 Estimator 的优势\n",
        "\n",
        "预制 Estimator 对最佳做法进行了编码，具有以下优势：\n",
        "\n",
        "- 确定计算图不同部分的运行位置，以及在单台机器或集群上实施策略的最佳做法。\n",
        "- 事件（摘要）编写和通用摘要的最佳做法。\n",
        "\n",
        "如果不使用预制 Estimator，则您必须自己实现上述功能。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oIaPjYgnZdn6"
      },
      "source": [
        "## 自定义 Estimator\n",
        "\n",
        "每个 Estimator（无论预制还是自定义）的核心是其*模型函数*，这是一种为训练、评估和预测构建计算图的方法。当您使用预制 Estimator 时，已经有人为您实现了模型函数。当使用自定义 Estimator 时，您必须自己编写模型函数。\n",
        "\n",
        "## 推荐工作流\n",
        "\n",
        "1. 假设存在一个合适的预制 Estimator，用它构建您的第一个模型，并将其结果作为基准。\n",
        "2. 使用此预制 Estimator 构建并测试您的整个流水线，包括数据的完整性和可靠性。\n",
        "3. 如果有其他合适的预制 Estimator，可通过运行实验确定哪个预制 Estimator 能够生成最佳结果。\n",
        "4. 如果可能，您可以通过构建自己的自定义 Estimator 进一步改进模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kRr7DGZxFApM"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IqR2PQG4ZaZ0"
      },
      "outputs": [],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "tfds.disable_progress_bar()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P7aPNnXUbN4j"
      },
      "source": [
        "## 从 Keras 模型创建 Estimator\n",
        "\n",
        "您可以使用 `tf.keras.estimator.model_to_estimator` 将现有的 Keras 模型转换为 Estimator。这样一来，您的 Keras 模型就可以利用 Estimator 的优势，例如分布式训练。\n",
        "\n",
        "实例化 Keras MobileNet V2 模型并用训练中使用的优化器、损失和指标来编译模型："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XE6NMcuGeDOP"
      },
      "outputs": [],
      "source": [
        "keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(\n",
        "    input_shape=(160, 160, 3), include_top=False)\n",
        "keras_mobilenet_v2.trainable = False\n",
        "\n",
        "estimator_model = tf.keras.Sequential([\n",
        "    keras_mobilenet_v2,\n",
        "    tf.keras.layers.GlobalAveragePooling2D(),\n",
        "    tf.keras.layers.Dense(1)\n",
        "])\n",
        "\n",
        "# Compile the model\n",
        "estimator_model.compile(\n",
        "    optimizer='adam',\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A3hcxzcEfYfX"
      },
      "source": [
        "从已编译的 Keras 模型创建 `Estimator`。Keras 模型的初始模型状态会保留在已创建的 `Estimator`中："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UCSSifirfyHk"
      },
      "outputs": [],
      "source": [
        "est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8jRNRVb_fzGT"
      },
      "source": [
        "您可以像对待任何其他 `Estimator` 一样对待派生的 `Estimator`。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rv9xJk51e1fB"
      },
      "outputs": [],
      "source": [
        "IMG_SIZE = 160  # All images will be resized to 160x160\n",
        "\n",
        "def preprocess(image, label):\n",
        "  image = tf.cast(image, tf.float32)\n",
        "  image = (image/127.5) - 1\n",
        "  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))\n",
        "  return image, label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fw8OjwujVBkc"
      },
      "outputs": [],
      "source": [
        "def train_input_fn(batch_size):\n",
        "  data = tfds.load('cats_vs_dogs', as_supervised=True)\n",
        "  train_data = data['train']\n",
        "  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)\n",
        "  return train_data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JMb0cuy0gbTi"
      },
      "source": [
        "要进行训练，可调用 Estimator 的训练函数："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4JsvMp8Jge80"
      },
      "outputs": [],
      "source": [
        "est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jvr_rAzngY9v"
      },
      "source": [
        "同样，要进行评估，可调用 Estimator 的评估函数："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kVNPqysQgYR2"
      },
      "outputs": [],
      "source": [
        "est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5HeTOvCYbjZb"
      },
      "source": [
        "有关详细信息，请参阅 `tf.keras.estimator.model_to_estimator` 文档。"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L"
      ],
      "name": "estimator.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
